#!/usr/bin/env python
# -*- coding:utf-8 -*-


'''
@Author : shouke
'''

import logging
import random
from copy import deepcopy

# import gevent
from locust import task
from locust.env import Environment
# from locust.stats import stats_printer
from locust.log import setup_logging
from locust.contrib.fasthttp import FastHttpUser

from config.config import log_level, slave_bind_host, master_bind_port
from components.decoraters.action_decorater import ActionDecorator

# Configure logging through locust's helper using the project-configured level.
# The second argument is passed as None -- presumably "no log file"; confirm
# against locust.log.setup_logging.
setup_logging(log_level, None)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

class TearDownException(Exception):
    """Signals that running a teardown step failed."""

class User(FastHttpUser):
    """Locust user that executes the action chains of the target scenario.

    On every task iteration the user asks the runner's configured iteration
    mode which chain id(s) to run, executes the steps of each chain through
    ``run_actions`` and then runs the ``teardown`` steps of the chains it has
    executed so far.
    """

    # Remaining weight budget per chain for the "by weight" mode. It is stored
    # on the class, so every User instance in this process shares it.
    chain_weight_list = []
    # Maps runner.iteration_mode to the name of the method that selects chains.
    iteration_mode_func_map = {1: 'get_all_chain_ids',
                               2: 'get_chain_id_by_weight',
                               3: 'get_chain_id_by_weight_percent_random',
                              }

    def __init__(self, environment):
        """Initialise per-user state and attach the registered component actions."""
        super().__init__(environment)
        self.client_id = id(self.client)
        self.chain_id = None  # id of the chain this user is currently executing
        self.chain_id_list = []  # ids of the chains this user has executed
        self.chain_list = []  # definitions of the chains this user has executed
        self.user_resource_dict = {}  # resources private to this user
        self.user_resource_dict['iteration_count'] = 0  # task iteration counter

        # Expose every registered component action as an attribute of this
        # user so that run_actions can look it up by its action name.
        for action, action_map in ActionDecorator.ACTION_FUNC_CLASS_MODULE_MAP.items():
            # A non-empty fromlist makes __import__ return the leaf module
            # instead of the top-level package.
            module = __import__(action_map.get('module_path'), fromlist=['True'])
            class_cls = getattr(module, action_map.get('class_name'))
            setattr(self, action, getattr(class_cls, action_map.get('function_name')))

    def get_chain_id_by_weight(self):
        """Return the id of the next chain to run according to the chain weights.

        The chain with the largest remaining weight is picked and its remaining
        weight is decremented. Once no chain has a positive remaining weight
        the budget is reloaded from the runner's configured weights.

        Raises:
            ValueError: if the runner's weight list is empty.
        """
        runner = self.environment.runner
        # Reload the shared budget on first use and whenever it is used up.
        # Using "<= 0" (rather than "all weights are exactly 0") keeps the
        # method from silently returning None for zero/negative weights.
        if not User.chain_weight_list or max(User.chain_weight_list) <= 0:
            User.chain_weight_list = deepcopy(runner.scenario_chain_weight_list)

        index = User.chain_weight_list.index(max(User.chain_weight_list))
        User.chain_weight_list[index] -= 1
        return runner.scenario_chain_id_list[index]

    def get_chain_id_by_weight_percent_random(self, sample_rate_list=None):
        """Return a randomly chosen chain id, weighted by each chain's share.

        Args:
            sample_rate_list: sequence of ``(sample, probability)`` pairs. When
                omitted or empty, the runner's ``chain_weight_percent_list`` is
                used.

        Returns:
            The ``sample`` of the selected pair.

        Raises:
            ValueError: if the probabilities do not add up to 1.
        """
        if not sample_rate_list:
            sample_rate_list = self.environment.runner.chain_weight_percent_list

        # Compare with a tolerance: sums such as 0.1 + 0.2 + 0.7 are not
        # exactly 1 in binary floating point.
        total_probability = sum(probability for _, probability in sample_rate_list)
        if abs(total_probability - 1) > 1e-9:
            raise ValueError("样本比例配置错误，样本占比之和必须为1!")

        random_normalized_num = random.random()  # x in the interval [0, 1)
        accumulated_probability = 0.0
        for sample, probability in sample_rate_list:
            accumulated_probability += probability
            if random_normalized_num < accumulated_probability:
                return sample

        # Rounding can leave the accumulated sum marginally below the random
        # number; fall back to the last sample instead of returning None.
        return sample_rate_list[-1][0]

    def get_all_chain_ids(self):
        """Return the ids of all chains of the scenario."""
        return self.environment.runner.scenario_chain_id_list

    def on_start(self):
        """Locust start hook; nothing to do."""

    def run_actions(self, actions):
        """Execute a sequence of steps.

        Steps without an ``action`` and steps whose action is ``teardown``
        (case-insensitive) are skipped. Every other step is dispatched to the
        attribute of this user named after the action.

        Args:
            actions: iterable of step dicts; a step that is executed must
                provide the keys ``action`` and ``name``.

        Raises:
            Exception: if an action name is not an attribute of this user, or
                whatever the action itself raises (logged, then re-raised).
        """
        runner = self.environment.runner
        for step in actions:
            action = step.get('action')
            if not action or action.lower() == 'teardown':
                continue
            logger.debug('【执行步骤】：%s', step['name'])
            logger.debug('【步骤配置】：%s', step)

            file_data_dict = runner.file_data_dict
            user_share_resource_dict = runner.user_share_resouce_dict

            # Checked outside the try block so the failure is logged only once.
            if not hasattr(self, action):
                logger.error('【步骤执行结果】失败，步骤名称：%s 步骤action：%s, 失败原因：未识别的 action', step.get('name'), action)
                raise Exception('未识别的 action：%s' % action)

            try:
                getattr(self, action)(step, user=self, chain_id=self.chain_id, file_data_dict=file_data_dict, user_share_resource_dict=user_share_resource_dict)
            except Exception as e:
                logger.error('【步骤执行结果】失败，步骤名称：%s 步骤action：%s, 失败原因：运行出错：%s', step.get('name'), action, e)
                raise
            logger.debug('【步骤执行结果】成功, 步骤名称：%s 步骤action：%s', step.get('name'), action)

    @task
    def locust_load_test_task(self):
        """Run one iteration: select the chain(s), execute them, run teardown steps.

        Exceptions raised by a chain are re-raised after the teardown steps
        have run. The user-level iteration counter is incremented in all cases.
        """
        try:
            runner = self.environment.runner
            chain_ids = getattr(self, User.iteration_mode_func_map[runner.iteration_mode])()
            if not isinstance(chain_ids, list):
                chain_ids = [chain_ids]

            for chain_id in chain_ids:
                try:
                    self.chain_id = chain_id
                    chain = runner.target_scenario[chain_id]
                    chain_actions = chain['chain']
                    if chain_id not in self.chain_id_list:
                        self.chain_id_list.append(chain_id)
                        self.chain_list.append(chain)

                    # Create the per-chain state only the first time this user
                    # runs the chain. The previous guard tested for client_id,
                    # a key that is never stored, so the counters below were
                    # wiped on every iteration.
                    if chain_id not in self.user_resource_dict:
                        # NOTE(review): this also clears the chain's entry in
                        # the dict shared by all users whenever a user runs the
                        # chain for the first time -- confirm this is intended.
                        runner.user_share_resouce_dict[chain_id] = {}
                        self.user_resource_dict[chain_id] = {
                            'iteration_count': 0,  # per-chain iteration counter
                            'func_exec_times': {},  # run counts of run-once functions
                            'teardown_tasks': chain.get('teardown') or [],  # run in on_stop
                        }

                    self.run_actions(chain_actions)
                except TearDownException:
                    # Already logged by run_actions; skip the generic log below.
                    raise
                except Exception as e:
                    logger.info('运行出错：%s', e)
                    raise
                finally:
                    # NOTE(review): this walks every chain the user has executed
                    # so far, not only the current one, and self.chain_id still
                    # refers to the current chain -- confirm this is intended.
                    for executed_chain in self.chain_list:
                        for item in executed_chain.get('chain'):
                            if str(item.get('action')).lower() == 'teardown':
                                logger.debug('正在执行teardown操作')
                                self.teardown(item)
                        self.user_resource_dict[executed_chain.get('id')]['iteration_count'] += 1
        finally:
            self.user_resource_dict['iteration_count'] += 1

    def teardown(self, step, *args, **kwargs):
        """Run the ``children`` steps of a teardown step.

        Args:
            step: teardown step dict; its ``children`` entry, if any, is the
                list of steps to execute.

        Raises:
            TearDownException: wrapping any exception raised by a child step.
        """
        try:
            if step.get('children'):
                self.run_actions(step['children'])
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise TearDownException('%s' % e) from e

    def on_stop(self):
        """Locust stop hook: run the recorded teardown tasks of every executed chain.

        Errors are logged and not propagated; the first error ends the loop.
        The per-user resources are reset afterwards in every case.
        """
        try:
            for chain_id in self.chain_id_list:
                self.chain_id = chain_id  # the chain whose teardown tasks run next
                self.run_actions(self.user_resource_dict[chain_id]['teardown_tasks'])
        except Exception as e:
            logger.error('%s', e)
        finally:
            # Reset the per-user resources.
            self.user_resource_dict = {}


if __name__ == '__main__':
    # Set up the environment
    env = Environment(user_classes=[User])
    env.create_worker_runner(slave_bind_host, master_bind_port) # create a worker runner for the environment above

    # Start a greenlet (coroutine) that periodically prints the current performance statistics
    # gevent.spawn(stats_printer(env.stats)) # LocustPlus implements its own automatic printer internally, so this is commented out

    # Wait for the runner's greenlet to finish
    env.runner.greenlet.join()



